LeJEPA Implementation

Implementation of LeJEPA for text
research-paper
JEPA
Author

Andrew Boney

Published

January 16, 2026

LeJEPA Implementation for Text

As a follow-up to my previous blog post, I want to implement a LeJEPA architecture adapted for text. The original LeJEPA paper focuses on vision tasks, using a latent prediction approach in which the model learns to predict representations of masked image patches from the visible context.

In this post, I’ll adapt these principles to text by implementing a transformer-based encoder that predicts latent representations of masked text spans. The key idea is to use span masking (masking contiguous chunks of text rather than random tokens), which encourages the model to learn higher-level semantic representations.
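To make the difference between the two masking schemes concrete, here is a minimal stdlib sketch. The helper names are hypothetical and the masks are plain Python lists rather than tensors; it only illustrates the two patterns, not the exact masking used later in this post.

```python
import random

def random_token_mask(seq_len, mask_ratio, rng):
    """Mask each position independently (BERT-style random tokens)."""
    return [rng.random() < mask_ratio for _ in range(seq_len)]

def contiguous_span_mask(seq_len, span_len, rng):
    """Mask one contiguous block of span_len positions at a random offset."""
    start = rng.randint(0, seq_len - span_len)
    return [start <= i < start + span_len for i in range(seq_len)]

def count_runs(mask):
    """Count separate masked runs (maximal groups of consecutive True values)."""
    runs, prev = 0, False
    for m in mask:
        if m and not prev:
            runs += 1
        prev = m
    return runs

rng = random.Random(0)
tok = random_token_mask(64, 0.25, rng)
span = contiguous_span_mask(64, 16, rng)
print(sum(tok), count_runs(tok))    # roughly 16 masked, scattered over many runs
print(sum(span), count_runs(span))  # exactly 16 masked, in a single run
```

With random tokens, most masked positions sit next to visible neighbours and can often be guessed from local context; a contiguous span removes that shortcut.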

Approach

I’ll implement a LeJEPA-style architecture with:

  • An encoder that processes visible text tokens.
  • A predictor that learns to predict the target encoder’s representations of masked spans.

The goal is to train this model from scratch on a text dataset and evaluate how well the learned representations capture semantic information. This is primarily an implementation exercise to better understand how the LeJEPA architecture works and how it can be adapted beyond vision tasks.

For this we’ll need to create functions and classes for:

  • A LeJEPAText class that defines the LeJEPA architecture.
  • A LeJEPATextDataModule class that loads and preprocesses the text data.

And use:

  • A text dataset, for which I’ll use a standard dataset from HuggingFace.
  • A tokenizer, for which I’ll use a standard open source tokenizer.

Data

Investigating the Dataset

prompt / response

What would be a good dataset to use for this? Please search the web

🤖Reply🤖

For a text LeJEPA self-supervised learning experiment, here are some good dataset options:

Best for this experiment:

  1. FineWeb / FineWeb-Edu: FineWeb includes 15T tokens of cleaned and deduplicated English web data, and FineWeb-Edu is a subset of it filtered for educational content. These datasets focus on cleaner, more informative content that helps models learn language patterns, reasoning, and structure, which is why they have become a popular choice for pretraining and fine-tuning large language models.

  2. WikiText-103: a classic choice for self-supervised text experiments. Clean Wikipedia articles with ~100M tokens, a manageable size for experimentation.

  3. OpenWebText: an open recreation of OpenAI’s WebText corpus, good for general text pretraining.

My recommendation for your experiment:

For a blog post / learning exercise, I’d suggest WikiText-103 (wikitext-103-raw-v1) because:

  • Manageable size (~100M tokens)
  • Clean, well-structured text
  • Fast to download and iterate with
  • Standard benchmark for language modeling

from datasets import load_dataset

# Quick to download, good for experimentation
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

If you want something larger later, you can use HuggingFaceFW/fineweb-edu with streaming:

# For larger scale (streaming avoids downloading the full dataset up front)
dataset = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)

Would you like me to help set up the DataModule for WikiText-103?

investigation

from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
dataset["train"][:5]
{'text': ['',
  ' = Valkyria Chronicles III = \n',
  '',
  ' Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n',
  " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \n"]}

prompt / response

Hmmm… it looks like in this dataset the text is split between documents. How can I deal with this in a dataset for training a text JEPA?

🤖Reply🤖

Great observation! The standard approach is to concatenate all documents then chunk into fixed-length sequences. This avoids short sequences and wasted padding.

from transformers import AutoTokenizer
from torch.utils.data import Dataset

class ChunkedTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=128):
        # Concatenate all non-empty text
        all_text = " ".join([t for t in dataset["text"] if t.strip()])
        
        # Tokenize everything at once
        tokens = tokenizer(all_text, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
        
        # Chunk into fixed-length sequences (drop remainder)
        num_chunks = len(tokens) // max_length
        self.chunks = tokens[:num_chunks * max_length].view(num_chunks, max_length)
    
    def __len__(self):
        return len(self.chunks)
    
    def __getitem__(self, idx):
        return {"input_ids": self.chunks[idx]}

# Example usage
tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_dataset = ChunkedTextDataset(dataset["train"], tokenizer, max_length=128)

print(f"Number of chunks: {len(train_dataset)}")
print(f"Sample shape: {train_dataset[0]['input_ids'].shape}")

Key points:

  • Concatenate: Join all documents with spaces (ignoring empty rows).
  • Tokenize once: More efficient than per-document tokenization.
  • Fixed chunks: Every sample is exactly max_length tokens, so no padding is needed.
  • Drop remainder: Discard the last partial chunk for simplicity.

This is a common approach in LLM pretraining pipelines. You can also add document separator tokens (for GPT-2, the <|endoftext|> token) if you want the model to learn document boundaries.
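Adding a separator is a small change once documents are tokenized separately. Below is a minimal stdlib sketch, assuming each document has already been turned into a list of token ids; the function name and the toy ids are hypothetical. With the GPT-2 tokenizer the separator would typically be `tokenizer.eos_token_id`, the id of `<|endoftext|>`.

```python
def concat_and_chunk(docs_token_ids, max_length, sep_id=None):
    """Concatenate per-document token id lists and cut them into fixed-length chunks.

    If sep_id is given, it is appended after every document so that document
    boundaries stay visible to the model. The final partial chunk is dropped.
    """
    stream = []
    for ids in docs_token_ids:
        stream.extend(ids)
        if sep_id is not None:
            stream.append(sep_id)
    num_chunks = len(stream) // max_length
    return [stream[i * max_length:(i + 1) * max_length] for i in range(num_chunks)]

docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(concat_and_chunk(docs, max_length=4))
# [[1, 2, 3, 4], [5, 6, 7, 8]]  (the trailing 9 is dropped)
print(concat_and_chunk(docs, max_length=4, sep_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Note that chunks still straddle documents; the separator only marks where one ends and the next begins.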

investigation

from transformers import AutoTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm

class ChunkedTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=128):
        # Concatenate all non-empty text
        all_text = " ".join([t for t in tqdm(dataset["text"]) if t.strip()])
        
        # Tokenize everything at once
        tokens = tokenizer(all_text, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
        
        # Chunk into fixed-length sequences (drop remainder)
        num_chunks = len(tokens) // max_length
        self.chunks = tokens[:num_chunks * max_length].view(num_chunks, max_length)
    
    def __len__(self):
        return len(self.chunks)
    
    def __getitem__(self, idx):
        return {"input_ids": self.chunks[idx]}
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Concatenate all non-empty text
all_text = " ".join([t for t in tqdm(dataset["train"]["text"]) if t.strip()])
100%|██████████| 36718/36718 [00:00<00:00, 113158.42it/s]
# Tokenize everything at once
tokens = tokenizer(all_text[:1000000], return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
Token indices sequence length is longer than the specified maximum sequence length for this model (219671 > 1024). Running this sequence through the model will result in indexing errors
len(all_text), len(tokens)
(10916756, 219671)

prompt / response

Can you change the tokenizer above to run in a batched way? Right now it appears to be crashing on the whole dataset.

🤖Reply🤖

from transformers import AutoTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
import torch

class ChunkedTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=128, batch_size=10000):
        # Filter non-empty text
        texts = [t for t in dataset["text"] if t.strip()]
        
        # Tokenize in batches
        all_tokens = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
            batch = texts[i:i + batch_size]
            encoded = tokenizer(batch, add_special_tokens=False)["input_ids"]
            for ids in encoded:
                all_tokens.extend(ids)
        
        tokens = torch.tensor(all_tokens)
        
        # Chunk into fixed-length sequences
        num_chunks = len(tokens) // max_length
        self.chunks = tokens[:num_chunks * max_length].view(num_chunks, max_length)
    
    def __len__(self):
        return len(self.chunks)
    
    def __getitem__(self, idx):
        return {"input_ids": self.chunks[idx]}

# Usage
tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_dataset = ChunkedTextDataset(dataset["train"], tokenizer, max_length=128)
print(f"Number of chunks: {len(train_dataset)}")

Key change: Instead of joining all text into one giant string, we tokenize documents in batches of 10,000 and concatenate the token IDs. This avoids the memory spike from both the huge concatenated string and the single tokenizer call.
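One further memory saving, not used in this post: `all_tokens` is a Python list, which stores every id as a pointer to a separate int object. Because GPT-2’s vocabulary has 50,257 tokens, every id fits in an unsigned 16-bit integer, so the stdlib `array` module can hold the same stream in (typically) 2 bytes per id. The helper below is hypothetical and assumes the same batch-of-lists structure that `tokenizer(batch)["input_ids"]` returns.

```python
from array import array

def flatten_token_ids(encoded_batches):
    """Flatten batches of per-document token id lists into one compact buffer.

    Uses unsigned 16-bit storage ('H'), which assumes every id < 65536.
    That holds for GPT-2's vocabulary but not for larger ones.
    """
    buf = array("H")
    for batch in encoded_batches:
        for ids in batch:
            buf.extend(ids)  # raises OverflowError if an id does not fit
    return buf

batches = [[[796, 569], [18354]], [[50256, 7496]]]
buf = flatten_token_ids(batches)
print(list(buf))  # [796, 569, 18354, 50256, 7496]
```

The buffer could then be turned into a tensor, for example with `torch.from_numpy(np.frombuffer(buf, dtype=np.uint16).astype(np.int64))`. Because `array` raises `OverflowError` for ids that do not fit, the approach fails loudly rather than silently corrupting data if the tokenizer is swapped for one with a bigger vocabulary.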

investigation

from transformers import AutoTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
import torch

class ChunkedTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=128, batch_size=10000):
        # Filter non-empty text
        texts = [t for t in dataset["text"] if t.strip()]
        
        # Tokenize in batches
        all_tokens = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
            batch = texts[i:i + batch_size]
            encoded = tokenizer(batch, add_special_tokens=False)["input_ids"]
            for ids in encoded:
                all_tokens.extend(ids)
        
        tokens = torch.tensor(all_tokens)
        
        # Chunk into fixed-length sequences
        num_chunks = len(tokens) // max_length
        self.chunks = tokens[:num_chunks * max_length].view(num_chunks, max_length)
    
    def __len__(self):
        return len(self.chunks)
    
    def __getitem__(self, idx):
        return self.chunks[idx]

# Usage
tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_dataset = ChunkedTextDataset(dataset["train"], tokenizer, max_length=256)
print(f"Number of chunks: {len(train_dataset)}")
Tokenizing: 100%|██████████| 3/3 [00:00<00:00,  3.55it/s]
Number of chunks: 9343
train_dataset[0]
tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311,
           73, 13090,   645,   569, 18354,  7496,   513,  1058,   791, 47398,
        17740,   357,  4960,  1058, 10545,   230,    99,   161,   254,   112,
         5641, 44444,  9202, 25084, 24440, 12675, 11839,    18,   837,  6578,
          764,   569, 18354,  7496,   286,   262, 30193,   513,  1267,   837,
         8811,  6412,   284,   355,   569, 18354,  7496, 17740,  6711,  2354,
         2869,   837,   318,   257, 16106,  2597,  2488,    12,    31,  2712,
         2008,   983,  4166,   416, 29490,   290,  6343,    13, 44206,   329,
          262, 14047, 44685,   764, 28728,   287,  3269,  2813,   287,  2869,
          837,   340,   318,   262,  2368,   983,   287,   262,   569, 18354,
         7496,  2168,   764, 12645,   278,   262,   976, 21748,   286, 16106,
          290,  1103,  2488,    12,    31,   640, 11327,   355,   663, 27677,
          837,   262,  1621,  4539, 10730,   284,   262,   717,   983,   290,
         5679,   262,   366, 17871,  5321,   366,   837,   257, 23634,  2422,
         4326,  7351,   262,  3277,   286,  7096,   544,  1141,   262,  5498,
         1898,  6839,  1810,   508,  1620,  3200,  2042,  4560,   290,   389,
        46852,  1028,   262, 11773,  4326,   366,  2199,   321,   265,    88,
        12552,   366,   764,   220,   198,   383,   983,  2540,  2478,   287,
         3050,   837,  6872,   625,   257,  1588,  6903,   286,   262,   670,
         1760,   319,   569, 18354,  7496, 17740,  2873,   764,  2893,   340,
        17383,   262,  3210,  3033,   286,   262,  2168,   837,   340,   635,
        25289,  3294, 16895,   837,   884,   355,  1642,   262,   983,   517,
        43486,   329,  2168, 29661,   764, 15684, 11915,   371,  4548,    64,
         8835,    73,   280,   290, 26777,  7286, 13704, 13231, 43354,  1111,
         4504,   422,  2180, 12784,   837,  1863,   351,   569, 18354,  7496,
        17740,  2873,  3437, 33687,  5303, 18024])
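As a rough sanity check that the batched tokenizer is not silently losing text, the numbers printed so far can be compared. The first experiment tokenized 1,000,000 characters into 219,671 tokens, the full joined train split is 10,916,756 characters, and the batched dataset produced 9343 chunks of 256 tokens. This is only an estimate: the characters-per-token ratio comes from the first million characters, and the batched version tokenizes documents separately (without the joining spaces) and drops up to 255 leftover tokens.

```python
# Numbers copied from the outputs above.
chars_total = 10_916_756      # len(all_text) for the full joined train split
chars_sample = 1_000_000      # characters tokenized in the first experiment
tokens_sample = 219_671       # tokens produced from that sample

chars_per_token = chars_sample / tokens_sample
est_total_tokens = chars_total / chars_per_token

num_chunks, max_length = 9343, 256
tokens_in_chunks = num_chunks * max_length  # tokens kept after dropping the remainder

gap = (est_total_tokens - tokens_in_chunks) / est_total_tokens
print(round(chars_per_token, 2))  # 4.55
print(round(est_total_tokens))    # about 2.40 million
print(tokens_in_chunks)           # 2391808
print(f"{gap:.2%}")               # relative gap between estimate and chunked total
```

The two figures agree to within about 0.3%, which is consistent with nothing substantial being dropped by the batched version.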

Data Loading and Preprocessing

Okay, so I’ve got a dataset and a tokenizer. However, the dataset needs some further changes before we can use it for training. To apply JEPA here, we need a “source” span of tokens and a “target” span of tokens, and we train the model so that the representation predicted from the source matches the representation of the target.

I’m thinking a good way to do this is to split texts in the dataset’s __getitem__ method. This will probably need a min_split_size and max_split_size to avoid splits that are overly large or small.

prompt / response

Is the approach that I’m suggesting above sensible? Is there anything I might be missing?

One thing I’m wondering is if having different splits as inputs to the model will have an impact on positional encoding.

🤖Reply🤖

Your intuition about positional encoding is spot on - this is a critical consideration!

The issue: If you just split a sequence and feed source/target as separate sequences starting from position 0, you lose the positional relationship. The model won’t know where in the original sequence the target spans were.

How JEPA handles this (vision analogy):

  • Context encoder: encodes the visible patches with their original position embeddings
  • Target encoder: provides representations for the masked (target) patches, again at their original positions
  • Predictor: takes context embeddings + target positions → predicts target representations

For text, I’d suggest this structure:

def __getitem__(self, idx):
    tokens = self.chunks[idx]  # Full sequence [seq_len]
    
    # Generate span mask (True = masked/target)
    mask = self.generate_span_mask(len(tokens))
    
    return {
        "input_ids": tokens,           # Full sequence
        "target_mask": mask,           # Which positions to predict
        "positions": torch.arange(len(tokens))  # Explicit positions
    }

Then in your model:

  • Context encoder: processes input_ids[~target_mask] with positions[~target_mask]
  • Target encoder: processes input_ids[target_mask] with positions[target_mask]
  • Predictor: uses context output + positions[target_mask] to predict target representations
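One practical wrinkle with this split: when `target_mask` comes from a batch of shape [B, L] and each row has a different split index, `input_ids[~target_mask]` returns a flat 1-D tensor, so the per-sample structure is lost. A common workaround is to keep the full [B, L] shape and hide the other span through attention, for example with the `src_key_padding_mask` argument of `torch.nn.TransformerEncoder`. The sketch below assumes a tiny randomly initialised encoder with learned position embeddings; the sizes and split indices are arbitrary, the same encoder is reused for both passes purely for illustration, and this is not necessarily how the final model in this post is built. (For comparison, I-JEPA runs its target encoder over the full, unmasked image and then selects the target patches from its output.)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, D, vocab = 2, 8, 16, 100

input_ids = torch.randint(0, vocab, (B, L))
positions = torch.arange(L).repeat(B, 1)        # [B, L] original positions
split = torch.tensor([3, 5])                    # per-sample split index (arbitrary)
target_mask = positions >= split.unsqueeze(1)   # True = target span

# Boolean indexing flattens the batch because rows have different context lengths.
print(input_ids[~target_mask].shape)            # torch.Size([8]) = 3 + 5 context tokens

tok_emb = nn.Embedding(vocab, D)
pos_emb = nn.Embedding(L, D)
layer = nn.TransformerEncoderLayer(d_model=D, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

# Every token keeps its original position embedding, whichever span it is in.
x = tok_emb(input_ids) + pos_emb(positions)

# Context pass: target positions are marked as padding, so no token attends to them.
# Only context_out[~target_mask] is meaningful downstream.
context_out = encoder(x, src_key_padding_mask=target_mask)
# Target pass: the mirror image, context positions are hidden instead.
target_out = encoder(x, src_key_padding_mask=~target_mask)
print(context_out.shape, target_out.shape)      # torch.Size([2, 8, 16]) for both
```

Running the encoder over the full length wastes some compute on hidden positions, but it keeps every tensor rectangular, which is usually simpler than packing variable-length spans.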

Key insight: The predictor needs to receive the target positions as input (often via learned position queries or just concatenated position embeddings) so it knows where to predict.
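The “learned position queries” idea can be sketched as follows, assuming the context encoder output is kept at full length [B, L, D] (as in a masked-attention setup). The `SpanPredictor` class and its hyperparameters are hypothetical: target positions are filled with a shared learned mask token plus a position embedding, everything passes through a small transformer, and predictions are read back at the target positions. A loss would then compare `pred[target_mask]` with the target encoder’s outputs at the same positions.

```python
import torch
import torch.nn as nn

class SpanPredictor(nn.Module):
    """Hypothetical predictor using learned position queries.

    Target positions are replaced by a shared learned mask token, every
    position gets a position embedding, and a small transformer lets the
    target queries attend to the context embeddings.
    """

    def __init__(self, dim=64, max_len=256, depth=2, nhead=4):
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(dim))
        self.pos_emb = nn.Embedding(max_len, dim)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nhead, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, context_emb, target_mask):
        # context_emb: [B, L, D] encoder output (only trusted where target_mask is False)
        # target_mask: [B, L] bool, True = position to predict
        B, L, _ = context_emb.shape
        positions = torch.arange(L, device=context_emb.device).repeat(B, 1)
        queries = self.mask_token.expand(B, L, -1)
        # Keep context embeddings, swap target positions for the mask token.
        x = torch.where(target_mask.unsqueeze(-1), queries, context_emb)
        x = x + self.pos_emb(positions)  # tells each query where it should predict
        return self.blocks(x)            # [B, L, D]; read predictions at target_mask

torch.manual_seed(0)
B, L, D = 2, 8, 64
context_emb = torch.randn(B, L, D)
target_mask = torch.arange(L).repeat(B, 1) >= torch.tensor([[3], [5]])
predictor = SpanPredictor(dim=D, max_len=L)
pred = predictor(context_emb, target_mask)
print(pred[target_mask].shape)  # torch.Size([8, 64]): 5 + 3 target positions
```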

Does this align with how you were thinking about it, or would you like me to sketch out the span masking logic?

prompt / response

Yeah, that makes sense. That being said, I don’t think we need the positions key here, as positions can be generated in the model.

Can you start by sketching out the generate_span_mask logic? I think I want this so that it splits into 2 spans at a random index, with a min length to avoid splits that are too small for the model to learn.

Try and make this minimal… I essentially just want a boolean mask based on a random index between min_length and seq_len - min_length

🤖Reply🤖

import torch
import random

def generate_span_mask(seq_len, min_length=32):
    """Split sequence into context (False) and target (True) at random index."""
    split_idx = random.randint(min_length, seq_len - min_length)
    mask = torch.zeros(seq_len, dtype=torch.bool)
    mask[split_idx:] = True  # Second half is target
    return mask

You could also randomly flip which half is target:

def generate_span_mask(seq_len, min_length=32):
    split_idx = random.randint(min_length, seq_len - min_length)
    mask = torch.zeros(seq_len, dtype=torch.bool)
    mask[split_idx:] = True
    if random.random() > 0.5:
        mask = ~mask  # Flip so first half is target
    return mask
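One detail worth checking in both versions: `random.randint(a, b)` includes both endpoints, so the context and the target each get at least `min_length` tokens, and the call only works when `seq_len >= 2 * min_length`. A small stdlib check follows; the helper name `split_index` is hypothetical, and it returns just the index rather than a tensor mask.

```python
import random

def split_index(seq_len, min_length, rng=random):
    """Pick a split so context [0, idx) and target [idx, seq_len) both have
    at least min_length tokens. random.randint includes both endpoints."""
    if seq_len < 2 * min_length:
        raise ValueError(f"seq_len={seq_len} is too short for min_length={min_length}")
    return rng.randint(min_length, seq_len - min_length)

rng = random.Random(0)
idxs = [split_index(256, 32, rng) for _ in range(10_000)]
print(min(idxs), 256 - max(idxs))  # smallest context and smallest target, both >= 32
print(split_index(64, 32, rng))    # 32: only one valid split when seq_len == 2 * min_length
```

Without the explicit guard, `randint` would still fail for too-short sequences, but with a less descriptive "empty range" error.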

prompt / response

I like version one of this. Can you now integrate this into the dataset logic?

🤖Reply🤖

import torch
import random
from torch.utils.data import Dataset
from tqdm import tqdm

class ChunkedTextDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=256, batch_size=10000, min_span_length=32):
        self.min_span_length = min_span_length
        
        # Filter non-empty text
        texts = [t for t in dataset["text"] if t.strip()]
        
        # Tokenize in batches
        all_tokens = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
            batch = texts[i:i + batch_size]
            encoded = tokenizer(batch, add_special_tokens=False)["input_ids"]
            for ids in encoded:
                all_tokens.extend(ids)
        
        tokens = torch.tensor(all_tokens)
        
        # Chunk into fixed-length sequences
        num_chunks = len(tokens) // max_length
        self.chunks = tokens[:num_chunks * max_length].view(num_chunks, max_length)
    
    def generate_span_mask(self, seq_len):
        split_idx = random.randint(self.min_span_length, seq_len - self.min_span_length)
        mask = torch.zeros(seq_len, dtype=torch.bool)
        mask[split_idx:] = True
        return mask
    
    def __len__(self):
        return len(self.chunks)
    
    def __getitem__(self, idx):
        tokens = self.chunks[idx]
        mask = self.generate_span_mask(len(tokens))
        return {
            "input_ids": tokens,
            "target_mask": mask,
        }

# Usage
tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_dataset = ChunkedTextDataset(dataset["train"], tokenizer, max_length=256)
print(f"Number of chunks: {len(train_dataset)}")
print("Sample:", train_dataset[0])
Tokenizing: 100%|██████████| 3/3 [00:01<00:00,  2.56it/s]
Number of chunks: 9343
Sample: {'input_ids': tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311,
           73, 13090,   645,   569, 18354,  7496,   513,  1058,   791, 47398,
        17740,   357,  4960,  1058, 10545,   230,    99,   161,   254,   112,
         5641, 44444,  9202, 25084, 24440, 12675, 11839,    18,   837,  6578,
          764,   569, 18354,  7496,   286,   262, 30193,   513,  1267,   837,
         8811,  6412,   284,   355,   569, 18354,  7496, 17740,  6711,  2354,
         2869,   837,   318,   257, 16106,  2597,  2488,    12,    31,  2712,
         2008,   983,  4166,   416, 29490,   290,  6343,    13, 44206,   329,
          262, 14047, 44685,   764, 28728,   287,  3269,  2813,   287,  2869,
          837,   340,   318,   262,  2368,   983,   287,   262,   569, 18354,
         7496,  2168,   764, 12645,   278,   262,   976, 21748,   286, 16106,
          290,  1103,  2488,    12,    31,   640, 11327,   355,   663, 27677,
          837,   262,  1621,  4539, 10730,   284,   262,   717,   983,   290,
         5679,   262,   366, 17871,  5321,   366,   837,   257, 23634,  2422,
         4326,  7351,   262,  3277,   286,  7096,   544,  1141,   262,  5498,
         1898,  6839,  1810,   508,  1620,  3200,  2042,  4560,   290,   389,
        46852,  1028,   262, 11773,  4326,   366,  2199,   321,   265,    88,
        12552,   366,   764,   220,   198,   383,   983,  2540,  2478,   287,
         3050,   837,  6872,   625,   257,  1588,  6903,   286,   262,   670,
         1760,   319,   569, 18354,  7496, 17740,  2873,   764,  2893,   340,
        17383,   262,  3210,  3033,   286,   262,  2168,   837,   340,   635,
        25289,  3294, 16895,   837,   884,   355,  1642,   262,   983,   517,
        43486,   329,  2168, 29661,   764, 15684, 11915,   371,  4548,    64,
         8835,    73,   280,   290, 26777,  7286, 13704, 13231, 43354,  1111,
         4504,   422,  2180, 12784,   837,  1863,   351,   569, 18354,  7496,
        17740,  2873,  3437, 33687,  5303, 18024]), 'target_mask': tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True])}

prompt / response

I’m now happy with how this is splitting and chunking.

Can you create a pytorch lightning data module for processing this data?

🤖Reply🤖

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

class LeJEPATextDataModule(pl.LightningDataModule):
    def __init__(
        self,
        tokenizer_name="gpt2",
        max_length=256,
        min_span_length=32,
        batch_size=32,
        num_workers=0,
    ):
        super().__init__()
        self.tokenizer_name = tokenizer_name
        self.max_length = max_length
        self.min_span_length = min_span_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        
    def setup(self, stage=None):
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
        
        self.train_dataset = ChunkedTextDataset(
            dataset["train"], tokenizer, self.max_length, min_span_length=self.min_span_length
        )
        self.val_dataset = ChunkedTextDataset(
            dataset["validation"], tokenizer, self.max_length, min_span_length=self.min_span_length
        )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

Usage:

dm = LeJEPATextDataModule(batch_size=32)
dm.setup()
print(f"Train batches: {len(dm.train_dataloader())}")
print(f"Val batches: {len(dm.val_dataloader())}")

# Check a batch
batch = next(iter(dm.train_dataloader()))
print(f"input_ids shape: {batch['input_ids'].shape}")
print(f"target_mask shape: {batch['target_mask'].shape}")
Tokenizing: 100%|██████████| 3/3 [00:00<00:00,  3.26it/s]
Tokenizing: 100%|██████████| 1/1 [00:00<00:00, 15.44it/s]
Train batches: 292
Val batches: 31
input_ids shape: torch.Size([32, 256])
target_mask shape: torch.Size([32, 256])
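One thing to be aware of with this setup: `generate_span_mask` draws from the global `random` module on every `__getitem__` call, so the validation set gets a different context/target split each epoch, which makes epoch-to-epoch validation loss noisier. A hypothetical fix is to derive the validation split from the sample index alone. The function name and `base_seed` below are illustrative, and wiring it into `ChunkedTextDataset` (for example via a flag that only the validation dataset sets) is an assumption left to the reader.

```python
import random

def val_split_index(idx, seq_len, min_length, base_seed=1234):
    """Return a split index that depends only on (base_seed, idx).

    A fresh random.Random seeded per sample means every validation pass
    scores the same context/target split for sample idx.
    """
    rng = random.Random(base_seed + idx)
    return rng.randint(min_length, seq_len - min_length)

first_pass = [val_split_index(i, 256, 32) for i in range(5)]
second_pass = [val_split_index(i, 256, 32) for i in range(5)]
print(first_pass == second_pass)  # True
```

Training samples can keep the fully random split, since varying it there acts as data augmentation.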

Model

Next, we need to define the model. There are a few things to work out here:

  • How to apply SIGReg.
  • Defining an encoder that can encode both the source and target spans.
  • Defining a predictor that can predict the target encoder’s representations of masked spans.

SIGReg

prompt / response

To start, I want to work out how to implement SIGReg. Can you implement this, referencing both the paper and this implementation on github. If easier I’m happy for you to use the github implementation directly.

🤖Reply🤖

read_url(url=https://github.com/galilai-group/lejepa)
{
  "id": "toolu_01JvPDFMtq6kKnbg8ei3JZ1r",
  "call": {
    "function": "read_url",
    "arguments": {
      "url": "https://github.com/galilai-group/lejepa"
    }
  },
  "result": "'# galilai-group/lejepa\\n\\n\\n## / Files\\n- LICENSE\\n- MINIMAL.md\\n- README.md\\n- \ud83d\udcc1 eval\\n- \ud83d\udcc1 figures\\n- \ud83d\udcc1 lejepa\\n- pyproject.toml\\n- \ud83d\udcc1 scripts\\n- setup.py\\n- \ud83d\udcc1 tests\\n\\n## README\\n# LeJEPA\\n**Lean Joint-Embedding Predictive Architecture (LeJEPA): Provable and Scalable Self-Supervised Learning Without the Heuristics**\\n[GitHub Repository](https://github.com/rbalestr-lab/lejepa)  \\n[arXiv:2511.08544](https://arxiv.org/abs/2511.08544)\\n---\\n\\nRush to our [minimal working example](MINIMAL.md) to see a full-fledge working example (ViT, inet).\\n\\n## Demo\\n\\n<img src=\"eval/output1.gif\" controls widt<TRUNCATED>"
}

I’ve got the key implementation. The SIGReg loss in LeJEPA works by:

  1. Slicing: project the high-dimensional embeddings onto random 1D directions.
  2. Univariate testing: apply a statistical test (such as Epps-Pulley) to check whether each projection is standard normal.
  3. Aggregation: average the test statistics across all slices.

Here’s a simplified implementation you can use directly:

import torch
import torch.nn as nn

class EppsPulley(nn.Module):
    """Epps-Pulley test for univariate normality via characteristic function."""
    
    def __init__(self, t_max=3.0, num_points=17):
        super().__init__()
        assert num_points % 2 == 1
        t = torch.linspace(0, t_max, num_points)
        dt = t_max / (num_points - 1)
        weights = torch.full((num_points,), 2 * dt)
        weights[[0, -1]] = dt
        phi = torch.exp(-0.5 * t ** 2)
        
        self.register_buffer("t", t)
        self.register_buffer("phi", phi)
        self.register_buffer("weights", weights * phi)
    
    def forward(self, x):
        # x: (*, N, K) where N=samples, K=slices
        N = x.size(-2)
        x_t = x.unsqueeze(-1) * self.t  # (*, N, K, num_points)
        
        cos_mean = torch.cos(x_t).mean(-3)  # (*, K, num_points)
        sin_mean = torch.sin(x_t).mean(-3)
        
        err = (cos_mean - self.phi).square() + sin_mean.square()
        return (err @ self.weights) * N  # (*, K)


class SIGReg(nn.Module):
    """Sketched Isotropic Gaussian Regularization loss."""
    
    def __init__(self, num_slices=1024, t_max=3.0, num_points=17):
        super().__init__()
        self.num_slices = num_slices
        self.univariate_test = EppsPulley(t_max, num_points)
    
    def forward(self, embeddings):
        # embeddings: (batch_size, embed_dim)
        # Standardize
        embeddings = (embeddings - embeddings.mean(0)) / (embeddings.std(0) + 1e-6)
        
        # Random projections
        D = embeddings.size(-1)
        A = torch.randn(D, self.num_slices, device=embeddings.device)
        A = A / A.norm(dim=0, keepdim=True)  # Normalize to unit vectors
        
        projected = embeddings @ A  # (batch_size, num_slices)
        
        # Apply univariate test and average
        stats = self.univariate_test(projected)  # (num_slices,)
        return stats.mean()

Usage:

sigreg = SIGReg(num_slices=1024)
embeddings = torch.randn(256, 768)  # batch of embeddings
loss = sigreg(embeddings)
print(f"SIGReg loss: {loss.item():.4f}")

The key insight: SIGReg pushes the embeddings towards an isotropic Gaussian, N(0, I). The statistic is smallest when the embeddings follow N(0, I), although with a finite batch it stays somewhat above zero even for genuinely Gaussian samples. The LeJEPA paper presents this regularizer as what prevents collapse without heuristics such as stop-gradient or an EMA target encoder.

sigreg = SIGReg(num_slices=1024)
embeddings = torch.randn(256, 768)  # batch of embeddings
loss = sigreg(embeddings)
print(f"SIGReg loss: {loss.item():.4f}")
SIGReg loss: 0.5811
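The collapse-prevention argument has one caveat worth testing: the SIGReg class above standardizes each embedding dimension over the batch before projecting. A near-constant (collapsed) batch with a little noise turns into unit-variance noise after that step, so it may look Gaussian to the test. The hypothetical `sliced_ep_statistic` function below re-implements the same slicing and Epps-Pulley statistic in one place, with standardization as a flag, and compares a Gaussian batch with a collapsed one. It is a sketch for probing the regularizer, not the LeJEPA reference code.

```python
import torch


def sliced_ep_statistic(embeddings, standardize=False, num_slices=256,
                        t_max=3.0, num_points=17, seed=0):
    """Epps-Pulley statistic against N(0, 1), averaged over random 1D slices.

    Same computation as the EppsPulley/SIGReg classes above, written as one
    function; `standardize` toggles the per-dimension standardization step.
    """
    if standardize:
        embeddings = (embeddings - embeddings.mean(0)) / (embeddings.std(0) + 1e-6)

    gen = torch.Generator().manual_seed(seed)
    n, d = embeddings.shape
    A = torch.randn(d, num_slices, generator=gen)
    A = A / A.norm(dim=0, keepdim=True)          # unit-norm slice directions
    projected = embeddings @ A                    # (n, num_slices)

    t = torch.linspace(0, t_max, num_points)
    dt = t_max / (num_points - 1)
    weights = torch.full((num_points,), 2 * dt)   # trapezoid rule over [-t_max, t_max]
    weights[[0, -1]] = dt
    phi = torch.exp(-0.5 * t ** 2)                # characteristic function of N(0, 1)

    x_t = projected.unsqueeze(-1) * t             # (n, num_slices, num_points)
    err = (torch.cos(x_t).mean(0) - phi).square() + torch.sin(x_t).mean(0).square()
    return ((err @ (weights * phi)) * n).mean().item()


torch.manual_seed(0)
gaussian = torch.randn(256, 64)                   # roughly N(0, I)
collapsed = 0.5 + 1e-3 * torch.randn(256, 64)     # every sample nearly identical

results = {
    (name, std): sliced_ep_statistic(x, standardize=std)
    for name, x in [("gaussian", gaussian), ("collapsed", collapsed)]
    for std in (False, True)
}
for (name, std), value in results.items():
    print(f"{name:9s} standardize={str(std):5s} -> {value:.3f}")
```

If the collapsed batch only scores high when `standardize=False`, that suggests the standardization step in SIGReg is worth revisiting, since the repr_std metric would then be the only thing catching collapse.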

Encoder Model

prompt / response

Thanks. Now, how can we build the encoder model? I want a transformer-based text encoder that can get representations of the source and targets, with masking based on the source/target split. We’ll also need a predictor that predicts the target from the context.

I was initially under the impression that the target encoder would be a copy of the source encoder that is updated as a moving average. However, on re-reading the LeJEPA paper, it appears that SIGReg removes the need for EMA. I can also see that other JEPA implementations, e.g. this repo, aren’t using stop gradients. I’m going to go ahead and use the same encoder for source and target; however, I will check for collapse (i.e. the encoder producing constant vectors), as this is a known issue with JEPA architectures.

Because I’m working in an environment with limited compute, I think it would be better for the predictor just to be a linear layer, rather than a transformer.

Let’s do this bit by bit, starting with creating an encoder and showing how we can get source and target representations in a batched way. I think the challenge here is that batches of sources and targets will have inconsistent lengths. If possible, I think we should deal with this by passing the mask when we get representations.

Can you use the code here https://github.com/jerber/lang-jepa/blob/main/src/encoder/models.py and here https://github.com/jerber/lang-jepa/blob/main/src/encoder/train.py as inspiration, particularly how the masking is applied. I’m happy for the positional encoding to be a simple embedding layer.

🤖Reply🤖

I see - the lang-jepa repo uses torch.no_grad() on the target encoder (same weights, just no gradient flow) and passes attention masks to handle variable lengths. Here’s a simplified encoder setup for your use case:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class TextEncoder(nn.Module):
    """Simple transformer encoder for text."""
    
    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4, max_length=512, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_length, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, input_ids, attention_mask=None):
        """
        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len) - 1 for valid tokens, 0 for masked
        Returns:
            (batch, seq_len, embed_dim)
        """
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
        
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)
        
        # Convert attention_mask to transformer format (True = ignore)
        if attention_mask is not None:
            src_key_padding_mask = (attention_mask == 0)
        else:
            src_key_padding_mask = None
        
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)
    
    def get_representations(self, input_ids, target_mask, pool="mean"):
        """
        Get separate context and target representations.
        
        Args:
            input_ids: (batch, seq_len)
            target_mask: (batch, seq_len) - True for target positions
            pool: how to aggregate token representations
        """
        # Get full sequence representations
        full_repr = self.forward(input_ids)  # (batch, seq_len, embed_dim)
        
        context_mask = ~target_mask  # True for context positions
        
        if pool == "mean":
            # Masked mean pooling for context
            context_repr = (full_repr * context_mask.unsqueeze(-1)).sum(dim=1)
            context_repr = context_repr / context_mask.sum(dim=1, keepdim=True).clamp(min=1)
            
            # Masked mean pooling for target
            target_repr = (full_repr * target_mask.unsqueeze(-1)).sum(dim=1)
            target_repr = target_repr / target_mask.sum(dim=1, keepdim=True).clamp(min=1)
        else:
            raise ValueError(f"Unknown pooling: {pool}")
        
        return context_repr, target_repr


class LinearPredictor(nn.Module):
    """Simple linear predictor: context repr -> predicted target repr."""
    
    def __init__(self, embed_dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )
    
    def forward(self, context_repr):
        return self.proj(context_repr)

Usage example:

# Initialize
vocab_size = 50257  # GPT-2 vocab size
encoder = TextEncoder(vocab_size, embed_dim=256, num_heads=4, num_layers=4)
predictor = LinearPredictor(embed_dim=256)

# Forward pass with a batch
batch = train_dataset[0]
input_ids = batch["input_ids"].unsqueeze(0)  # Add batch dim
target_mask = batch["target_mask"].unsqueeze(0)

context_repr, target_repr = encoder.get_representations(input_ids, target_mask)
predicted_repr = predictor(context_repr)

print(f"Context repr: {context_repr.shape}")
print(f"Target repr: {target_repr.shape}")
print(f"Predicted repr: {predicted_repr.shape}")

Key design choices:

  • A single encoder processes the full sequence, then we mask-pool the outputs to get context and target representations.
  • This avoids variable-length batching complexity.
  • During training, target_repr is computed with torch.no_grad() (as in lang-jepa) to prevent collapse.
  • The predictor is just a linear layer with LayerNorm, as you requested.
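Masked mean pooling is what turns one full-length forward pass into a single context vector and a single target vector per sequence. A tiny worked example, using a hypothetical `masked_mean` helper that performs the same pooling as `get_representations` (with the same zero-count guard):

```python
import torch


def masked_mean(hidden, mask):
    """Average hidden states over positions where mask is True.

    hidden: (batch, seq_len, dim); mask: (batch, seq_len) bool.
    """
    weights = mask.unsqueeze(-1).to(hidden.dtype)   # (batch, seq_len, 1)
    summed = (hidden * weights).sum(dim=1)          # (batch, dim)
    counts = weights.sum(dim=1).clamp(min=1)        # avoid dividing by zero
    return summed / counts


hidden = torch.tensor([[[1.0, 10.0],
                        [3.0, 30.0],
                        [5.0, 50.0],
                        [7.0, 70.0]]])              # one sequence, 4 tokens, dim 2
target_mask = torch.tensor([[False, False, True, True]])  # last two tokens are the target span

context_vec = masked_mean(hidden, ~target_mask)
target_vec = masked_mean(hidden, target_mask)
print(context_vec.tolist())  # [[2.0, 20.0]]: mean of tokens 0 and 1
print(target_vec.tolist())   # [[6.0, 60.0]]: mean of tokens 2 and 3
```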

prompt / response

The mask here is unused in the encoder here which I’m finding a bit confusing. Can you confirm if a TransformerEncoder in pytorch can see ahead if there is no mask… i.e. can value in position 10 see the value in position 11?

🤖Reply🤖

Yes, TransformerEncoder is bidirectional by default - position 10 can absolutely see position 11 (and all other positions). It uses full self-attention with no causal masking.

The src_key_padding_mask parameter is only for padding - telling the model which positions are padding tokens to ignore. Since your sequences are fixed-length chunks (no padding), the mask is indeed unused.

This is actually correct for JEPA. You want bidirectional attention because:

  • Context tokens should attend to all other context tokens.
  • Target tokens should attend to all other target tokens.
  • You’re learning representations, not doing autoregressive generation.

If you wanted causal (can’t see ahead) attention, you’d need to pass a causal mask (not src_key_padding_mask):

# Causal mask (if you wanted it - you don't for JEPA)
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
x = self.transformer(x, mask=causal_mask)

So the current setup is correct - the encoder sees all positions bidirectionally, which is what you want for learning JEPA-style representations.
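A quick way to check the bidirectional claim is to perturb the input at position 11 and see whether the output at position 10 changes. The sketch below uses a small stand-alone nn.TransformerEncoder rather than the post's TextEncoder, sets dropout to 0.0 so repeated passes are deterministic, and builds the causal mask by hand with torch.triu.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, dim = 12, 16

layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

x = torch.randn(1, seq_len, dim)
x_changed = x.clone()
x_changed[0, 11] += 1.0  # perturb only position 11

# Causal mask: -inf above the diagonal, so position i only attends to positions <= i
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

with torch.no_grad():
    # No mask: full bidirectional attention, so position 10 sees position 11
    bi_changed = not torch.allclose(encoder(x)[0, 10], encoder(x_changed)[0, 10], atol=1e-6)
    # Causal mask: position 10 can no longer see position 11
    causal_changed = not torch.allclose(
        encoder(x, mask=causal_mask)[0, 10],
        encoder(x_changed, mask=causal_mask)[0, 10],
        atol=1e-6,
    )

print("position 10 changes without a mask:", bi_changed)
print("position 10 changes with a causal mask:", causal_changed)
```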

prompt / response

But in this context you pass both the source and target through the encoder without a mask, so surely the context values will be able to see the target values and vice versa. I can see in https://github.com/jerber/lang-jepa/blob/main/src/encoder/train.py on lines 136 and 144/145 that the source and target are masked, but in your version no masking is applied.

🤖Reply🤖

You’re absolutely right - I made an error. The current code lets context tokens attend to target tokens during self-attention, which leaks information.

The fix is to apply the mask during the encoder forward pass:

def get_context_target_representations(context_encoder, target_encoder, input_ids, target_mask):
    """
    Get pooled context and target representations with proper masking.
    
    Context encoder only sees context positions.
    Target encoder only sees target positions.
    """
    context_mask = ~target_mask  # True for context positions
    
    # Context encoder: mask out target positions during attention
    context_hidden = context_encoder(input_ids, attention_mask=context_mask)
    context_hidden = context_hidden * context_mask.unsqueeze(-1).float()
    context_repr = context_hidden.sum(dim=1) / context_mask.sum(dim=1, keepdim=True).float()
    
    # Target encoder: mask out context positions during attention
    with torch.no_grad():
        target_hidden = target_encoder(input_ids, attention_mask=target_mask)
        target_hidden = target_hidden * target_mask.unsqueeze(-1).float()
        target_repr = target_hidden.sum(dim=1) / target_mask.sum(dim=1, keepdim=True).float()
    
    return context_repr, target_repr

Now:

  • context_encoder gets attention_mask=context_mask, so it can only attend to context positions.
  • target_encoder gets attention_mask=target_mask, so it can only attend to target positions.

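Inside the encoder’s forward pass this mask becomes src_key_padding_mask, which stops every position from attending to the masked-out keys. The masked positions still produce outputs (for example, target positions in the context pass attend to the context tokens), which is why the hidden states are multiplied by the mask before mean pooling.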
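One way to confirm the masking works is to change only the target tokens and check that the pooled context representation doesn’t move. The sketch below uses a bare nn.Embedding plus nn.TransformerEncoder (no positional embedding, dropout 0.0) as a stand-in for TextEncoder, and a hypothetical `pooled_context` helper. It checks the masking idea rather than the post’s actual model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, seq_len, dim = 100, 12, 16

embed = nn.Embedding(vocab_size, dim)
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)


def pooled_context(input_ids, context_mask, use_mask=True):
    """Mean-pool encoder outputs over context positions."""
    padding_mask = ~context_mask if use_mask else None  # True = key is ignored
    hidden = encoder(embed(input_ids), src_key_padding_mask=padding_mask)
    weights = context_mask.unsqueeze(-1).float()
    return (hidden * weights).sum(dim=1) / weights.sum(dim=1)


input_ids = torch.randint(0, vocab_size, (1, seq_len))
target_mask = torch.zeros(1, seq_len, dtype=torch.bool)
target_mask[0, 4:8] = True  # one contiguous target span in the middle
context_mask = ~target_mask

# Swap every target token for a different id; context tokens are untouched
changed_ids = input_ids.clone()
changed_ids[target_mask] = (changed_ids[target_mask] + 1) % vocab_size

with torch.no_grad():
    leak_masked = not torch.allclose(
        pooled_context(input_ids, context_mask),
        pooled_context(changed_ids, context_mask),
        atol=1e-6,
    )
    leak_unmasked = not torch.allclose(
        pooled_context(input_ids, context_mask, use_mask=False),
        pooled_context(changed_ids, context_mask, use_mask=False),
        atol=1e-6,
    )

print("context repr depends on target tokens (with mask):", leak_masked)
print("context repr depends on target tokens (no mask):", leak_unmasked)
```

Every context position can still attend to at least one context key here, which matters: a query whose keys are all masked out gets NaN attention weights.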

Full Class

prompt / response

Okay, I think I’m happy with this now. Can you now combine these previous steps into a full LightningModule?

🤖Reply🤖

import torch
import torch.nn as nn
import pytorch_lightning as pl

class TextEncoder(nn.Module):
    """Simple transformer encoder for text."""
    
    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4, max_length=512, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_length, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, input_ids, attention_mask=None):
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
        
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)
        
        if attention_mask is not None:
            src_key_padding_mask = (attention_mask == 0)
        else:
            src_key_padding_mask = None
        
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)


class LinearPredictor(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )
    
    def forward(self, x):
        return self.proj(x)


class LeJEPAText(pl.LightningModule):
    def __init__(
        self,
        vocab_size=50257,
        embed_dim=256,
        num_heads=4,
        num_layers=4,
        max_length=512,
        dropout=0.1,
        lr=1e-4,
        sigreg_weight=0.1,
        num_slices=1024,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Shared encoder for context and target
        self.encoder = TextEncoder(
            vocab_size, embed_dim, num_heads, num_layers, max_length, dropout
        )
        self.predictor = LinearPredictor(embed_dim)
        self.sigreg = SIGReg(num_slices=num_slices)
        
        self.lr = lr
        self.sigreg_weight = sigreg_weight
    
    def get_representations(self, input_ids, target_mask):
        """Get context and target representations with proper masking."""
        context_mask = ~target_mask
        
        # Context: only attend to context positions
        context_hidden = self.encoder(input_ids, attention_mask=context_mask)
        context_hidden = context_hidden * context_mask.unsqueeze(-1).float()
        context_repr = context_hidden.sum(dim=1) / context_mask.sum(dim=1, keepdim=True).float()
        
        # Target: only attend to target positions (no gradient)
        with torch.no_grad():
            target_hidden = self.encoder(input_ids, attention_mask=target_mask)
            target_hidden = target_hidden * target_mask.unsqueeze(-1).float()
            target_repr = target_hidden.sum(dim=1) / target_mask.sum(dim=1, keepdim=True).float()
        
        return context_repr, target_repr
    
    def forward(self, input_ids, target_mask):
        context_repr, target_repr = self.get_representations(input_ids, target_mask)
        predicted_repr = self.predictor(context_repr)
        return predicted_repr, target_repr, context_repr
    
    def compute_loss(self, batch):
        input_ids = batch["input_ids"]
        target_mask = batch["target_mask"]
        
        predicted_repr, target_repr, context_repr = self(input_ids, target_mask)
        
        # MSE loss between predicted and actual target representations
        prediction_loss = nn.functional.mse_loss(predicted_repr, target_repr)
        
        # SIGReg on context representations to prevent collapse
        sigreg_loss = self.sigreg(context_repr)
        
        total_loss = prediction_loss + self.sigreg_weight * sigreg_loss
        
        return total_loss, prediction_loss, sigreg_loss, context_repr
    
    def training_step(self, batch, batch_idx):
        total_loss, pred_loss, sigreg_loss, context_repr = self.compute_loss(batch)
        
        # Monitor for collapse
        repr_std = context_repr.std(dim=0).mean()
        
        self.log("train/loss", total_loss, prog_bar=True)
        self.log("train/pred_loss", pred_loss)
        self.log("train/sigreg_loss", sigreg_loss)
        self.log("train/repr_std", repr_std, prog_bar=True)
        
        return total_loss
    
    def validation_step(self, batch, batch_idx):
        total_loss, pred_loss, sigreg_loss, context_repr = self.compute_loss(batch)
        
        repr_std = context_repr.std(dim=0).mean()
        
        self.log("val/loss", total_loss, prog_bar=True)
        self.log("val/pred_loss", pred_loss)
        self.log("val/sigreg_loss", sigreg_loss)
        self.log("val/repr_std", repr_std)
        
        return total_loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

Usage:

model = LeJEPAText(vocab_size=50257, embed_dim=256, num_heads=4, num_layers=4)

# Test with a batch
batch = next(iter(dm.train_dataloader()))
loss = model.training_step(batch, 0)
print(f"Loss: {loss.item():.4f}")

Key features:

  • repr_std is logged to monitor collapse (if it drops near 0, the representations are collapsing).
  • torch.no_grad() on the target to prevent trivial solutions.
  • SIGReg on the context representations for regularization.
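The usage example above needs the data module’s dataloader. To smoke-test the module without it, a synthetic batch with the same two keys is enough. `make_synthetic_batch` below is hypothetical: it draws random token ids and marks one contiguous target span per row, which approximates span masking but is not the LeJEPATextDataModule’s actual logic.

```python
import torch


def make_synthetic_batch(batch_size=4, seq_len=64, vocab_size=50257,
                         min_span_length=8, max_span_length=16, seed=0):
    """Random token ids plus one contiguous target span per sequence.

    Hypothetical stand-in for a data module batch: same keys
    ("input_ids", "target_mask"), but not the data module's own masking.
    """
    assert max_span_length < seq_len, "leave some context tokens"
    gen = torch.Generator().manual_seed(seed)
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), generator=gen)
    target_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)

    for i in range(batch_size):
        span = int(torch.randint(min_span_length, max_span_length + 1, (1,), generator=gen))
        start = int(torch.randint(0, seq_len - span + 1, (1,), generator=gen))
        target_mask[i, start:start + span] = True  # contiguous target span

    return {"input_ids": input_ids, "target_mask": target_mask}


batch = make_synthetic_batch()
print(batch["input_ids"].shape, batch["target_mask"].sum(dim=1))
```

With the LeJEPAText class above in scope, `model.compute_loss(make_synthetic_batch())` should return the total, prediction and SIGReg losses plus the context representations, which is a quick shape check before training.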

Initiate and Check Model

model = LeJEPAText(vocab_size=50257, embed_dim=64, num_heads=2, num_layers=2)

num_params = sum([p.numel() for p in model.parameters()])
print("Number of Parameters:", num_params)
Number of Parameters: 3353600
%%time
context_repr, target_repr = model.get_representations(**batch)
context_repr, target_repr
CPU times: user 8.3 s, sys: 844 ms, total: 9.15 s
Wall time: 385 ms
(tensor([[-0.5128, -0.0883,  0.1644,  ...,  0.2210,  0.1130,  0.1568],
         [-0.5430, -0.2594, -0.1270,  ...,  0.1448, -0.0542,  0.1635],
         [-0.3083,  0.1464, -0.0937,  ...,  0.0654,  0.2852, -0.1501],
         ...,
         [-0.4473, -0.1915, -0.1485,  ...,  0.2285, -0.2723, -0.1203],
         [-0.4323, -0.1273,  0.0560,  ...,  0.1686, -0.0838,  0.1050],
         [-0.4209,  0.0393, -0.0346,  ...,  0.1213, -0.0704,  0.1748]],
        grad_fn=<DivBackward0>),
 tensor([[-0.4988, -0.0431,  0.3527,  ...,  0.1180,  0.1283,  0.0995],
         [-0.4808, -0.0066, -0.0075,  ...,  0.1063,  0.0564,  0.1553],
         [-0.3580, -0.0674,  0.0278,  ...,  0.1353,  0.0917,  0.1066],
         ...,
         [-0.4294,  0.0062, -0.1615,  ...,  0.3190, -0.1362,  0.2172],
         [-0.4899, -0.2853, -0.0208,  ...,  0.0021,  0.0751,  0.3868],
         [-0.5488,  0.0648,  0.0065,  ..., -0.0263,  0.0421,  0.1304]]))
%%time
predicted_repr = model.predictor(context_repr)
predicted_repr
CPU times: user 38.5 ms, sys: 0 ns, total: 38.5 ms
Wall time: 1.06 ms
tensor([[ 0.5477, -0.7017, -0.2800,  ...,  1.4106, -0.6975, -0.2430],
        [-0.1500, -1.1685,  0.4107,  ...,  0.1386, -1.0401,  0.6272],
        [ 1.0390, -1.0849, -0.7849,  ...,  0.4333, -0.6909,  0.9622],
        ...,
        [-0.3853, -1.1286,  0.0256,  ..., -0.0637, -0.9749,  0.4740],
        [ 0.5579, -1.2967, -0.4796,  ...,  0.5714, -0.5915,  0.4055],
        [ 0.4296, -0.5910, -0.1715,  ...,  0.2881, -1.1545,  0.4743]],
       grad_fn=<NativeLayerNormBackward0>)
%%time
prediction_loss = nn.functional.mse_loss(predicted_repr, target_repr)
prediction_loss
CPU times: user 489 μs, sys: 20 μs, total: 509 μs
Wall time: 348 μs
tensor(1.1227, grad_fn=<MseLossBackward0>)
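`nn.functional.mse_loss` uses `reduction="mean"` by default, so the value above is averaged over every batch × `embed_dim` entry. The predictor output ends in a LayerNorm (its grad_fn is `NativeLayerNormBackward0`), so before training each predicted row has roughly zero mean and unit variance, while the target values printed earlier are much smaller. Writing the MSE as E[p²] + E[t²] - 2E[pt] suggests why the untrained loss sits a little above 1. The check below uses random stand-ins with an assumed target scale of 0.2, not the model's actual tensors:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, dim = 32, 64

# Stand-ins: LayerNorm-ed predictions (like the predictor output) and small-scale targets (scale assumed)
predicted = nn.LayerNorm(dim)(torch.randn(batch, dim))
target = 0.2 * torch.randn(batch, dim)

mse = nn.functional.mse_loss(predicted, target)  # mean over all batch * dim entries
# The same quantity, decomposed: E[p^2] + E[t^2] - 2 E[p t]
decomposed = predicted.pow(2).mean() + target.pow(2).mean() - 2 * (predicted * target).mean()
pred_second_moment = predicted.pow(2).mean()  # ~1, since each row is normalized to unit variance

print(f"mse={mse.item():.4f}  decomposed={decomposed.item():.4f}  E[p^2]={pred_second_moment.item():.4f}")
```

Under this reading, part of the initial prediction loss is just the scale mismatch between normalized predictions and unnormalized targets.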
%%time
sigreg_loss = model.sigreg(context_repr)
sigreg_loss
CPU times: user 38.1 ms, sys: 0 ns, total: 38.1 ms
Wall time: 2.25 ms
tensor(0.8375, grad_fn=<MeanBackward0>)
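SIGReg pushes the embedding distribution towards an isotropic Gaussian by testing many random 1D projections of it. `model.sigreg` isn't repeated in this section, so the following is an independent minimal sketch rather than the post's implementation: project onto random unit directions, compare each projection's empirical characteristic function with that of N(0, 1) (an Epps-Pulley-style statistic), and average over directions. The number of directions, the t-grid, the Gaussian weighting, the lack of any sample-size scaling, and the weight `lam` used to mix the two losses are all assumptions:

```python
import torch


def sigreg_sketch(z: torch.Tensor, num_directions: int = 256, num_points: int = 17, seed: int = 0) -> torch.Tensor:
    """Epps-Pulley-style isotropic-Gaussian test on random 1D projections of z.

    z: (N, D) batch of embeddings. Returns a scalar that is near zero when z ~ N(0, I).
    All defaults are illustrative assumptions, not settings taken from the paper.
    """
    _, d = z.shape
    g = torch.Generator().manual_seed(seed)
    # Random unit-norm directions, one per column: (D, M)
    directions = torch.randn(d, num_directions, generator=g).to(device=z.device, dtype=z.dtype)
    directions = directions / directions.norm(dim=0, keepdim=True)
    proj = z @ directions  # (N, M): each column should look like N(0, 1) if z is isotropic

    # Frequencies t at which the characteristic functions are compared
    t = torch.linspace(-5.0, 5.0, num_points, device=z.device, dtype=z.dtype)
    gauss_cf = torch.exp(-0.5 * t**2)  # characteristic function of N(0, 1), reused as the weight w(t)
    xt = proj.unsqueeze(-1) * t  # (N, M, T)
    ecf_real = torch.cos(xt).mean(dim=0)  # (M, T) real part of the empirical characteristic function
    ecf_imag = torch.sin(xt).mean(dim=0)  # (M, T) imaginary part

    # Weighted squared distance |ecf(t) - gauss_cf(t)|^2 * w(t), integrated over t
    sq_err = (ecf_real - gauss_cf) ** 2 + ecf_imag**2
    per_direction = torch.trapezoid(sq_err * gauss_cf, t, dim=-1)  # (M,) one statistic per direction
    return per_direction.mean()


def lejepa_style_loss(pred_loss: torch.Tensor, sigreg_loss: torch.Tensor, lam: float = 0.05) -> torch.Tensor:
    """Hypothetical convex combination of the prediction and SIGReg terms."""
    return (1 - lam) * pred_loss + lam * sigreg_loss


torch.manual_seed(0)
print(sigreg_sketch(torch.randn(2048, 16)))  # close to zero: samples are already isotropic Gaussian
print(sigreg_sketch(torch.zeros(2048, 16)))  # much larger: every embedding collapsed to one point
```

The statistic is zero at t = 0 by construction (both characteristic functions equal 1 there), so it only responds to differences in spread and shape, which is what makes it penalize collapse.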

Train

Now, let’s use a Lightning trainer and fit the model.

🤖Reply🤖


from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

# Initialize data and model
dm = LeJEPATextDataModule(batch_size=32, max_length=256, min_span_length=32)
model = LeJEPAText(vocab_size=50257, embed_dim=64, num_heads=2, num_layers=2, lr=1e-4)

# Trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    callbacks=[
        RichProgressBar(),
        ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=1),
    ],
    log_every_n_steps=10,
)

# Train
trainer.fit(model, dm)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

# Initialize data and model
dm = LeJEPATextDataModule(batch_size=32, max_length=256, min_span_length=32)
model = LeJEPAText(vocab_size=50257, embed_dim=64, num_heads=2, num_layers=2, max_length=256, lr=1e-4)

# Trainer
trainer = Trainer(
    max_epochs=2,
    accelerator="auto",
    callbacks=[
        RichProgressBar(),
        ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=1),
    ],
    log_every_n_steps=50,
    val_check_interval=50
)

# Train
trainer.fit(model, dm)
Epoch 0/1  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/292 0:00:00 • -:--:-- 0.00it/s v_num: 5.000 val/loss: 1.134
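model = LeJEPAText.load_from_checkpoint("lightning_logs/version_4/checkpoints/epoch=1-step=542.ckpt")
model.eval()  # inference mode (disables dropout, if any)

# Mean-pool the per-token encoder outputs into one vector per example
with torch.no_grad():
    representation = model.encoder(batch["input_ids"]).mean(dim=1)
print(representation.shape)  # print explicitly: only the last expression in a cell is displayed
representation.mean(dim=1), representation.std(dim=1)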
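The cell above averages over `dim=1`, i.e. across the features of each example (assuming `model.encoder` returns `(batch, seq_len, embed_dim)` outputs, so `representation` is `(batch, embed_dim)`). Since SIGReg pushes embeddings towards an isotropic Gaussian, a complementary check is per-dimension statistics across the batch (`dim=0`) plus the spread of the covariance spectrum, which flags dimensional collapse. A minimal sketch on random stand-ins rather than the trained model; the entropy-based effective rank is one common summary, chosen here as an assumption:

```python
import torch


def embedding_diagnostics(z: torch.Tensor) -> dict:
    """Summary stats for a (num_examples, embed_dim) batch of embeddings.

    For an isotropic N(0, I) target: per-dimension means ~0, per-dimension
    standard deviations ~1, and an effective rank close to embed_dim.
    """
    per_dim_mean = z.mean(dim=0)  # average over examples, one value per dimension
    per_dim_std = z.std(dim=0)  # spread across examples; ~0 everywhere signals collapse
    cov = torch.cov(z.T)  # (embed_dim, embed_dim); torch.cov treats rows as variables
    eigvals = torch.linalg.eigvalsh(cov).clamp(min=0)  # clip tiny negative values from round-off
    p = eigvals / eigvals.sum().clamp(min=1e-12)  # normalized spectrum
    entropy = -torch.special.xlogy(p, p).sum()  # xlogy(0, 0) = 0, so zero eigenvalues are safe
    return {
        "mean_abs_mean": per_dim_mean.abs().mean().item(),
        "mean_std": per_dim_std.mean().item(),
        "effective_rank": entropy.exp().item(),
    }


torch.manual_seed(0)
isotropic = torch.randn(4096, 64)
rank_one = torch.randn(4096, 1) * torch.randn(1, 64)  # every embedding lies on a single line
iso_stats = embedding_diagnostics(isotropic)
collapsed_stats = embedding_diagnostics(rank_one)
print(iso_stats)
print(collapsed_stats)
```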

Conclude

Above we have a minimal implementation of LeJEPA on text. This has definitely helped my understanding of how JEPA architectures work, and I think it could provide a good foundation for further implementations of LeJEPA in text domains.

There are quite a few things I’d want to do next, including:

  • Scale up the model and dataset.
  • Add proper logging and training monitoring.
  • Run downstream evaluations on text classification datasets.
  • Add a decoder and use it for text generation.
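For the downstream evaluations, a common first step with self-supervised encoders is a linear probe: freeze the encoder, embed each example, and train only a linear classifier on top. A minimal full-batch sketch on synthetic stand-in embeddings; the function name, optimizer settings and data are hypothetical, and with real data the embeddings would come from something like the mean-pooled `model.encoder` output, computed under `torch.no_grad()`:

```python
import torch
from torch import nn


def linear_probe_accuracy(train_x, train_y, test_x, test_y, num_classes, epochs=200, lr=0.05):
    """Train a linear classifier on frozen embeddings and return test accuracy."""
    probe = nn.Linear(train_x.shape[1], num_classes)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    for _ in range(epochs):  # full-batch training; fine for small probe datasets
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(probe(train_x), train_y)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        predictions = probe(test_x).argmax(dim=1)
    return (predictions == test_y).float().mean().item()


# Synthetic stand-in: 3 classes whose embeddings cluster around different centres
torch.manual_seed(0)
centres = 3 * torch.randn(3, 64)
labels = torch.randint(0, 3, (600,))
embeddings = centres[labels] + torch.randn(600, 64)
acc = linear_probe_accuracy(embeddings[:500], labels[:500], embeddings[500:], labels[500:], num_classes=3)
print(f"probe accuracy: {acc:.2f}")
```

Keeping the probe linear means the score mostly reflects how linearly separable the frozen representations already are, rather than the capacity of the classifier.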

Building this has been a pain. A couple of key issues:

  • The LLM doesn’t understand JEPA well and has made basic mistakes in the implementation. Most of the prompts above are the result of multiple revisions, plus my own rewrites after double-checking other implementations and the LeJEPA paper. I suspect that because it’s a relatively new technique, it isn’t well represented in the training data.
  • solveit keeps crashing while I’m running code. I think this is just because compute is limited and it’s running out of memory. In the future, for these kinds of tasks I’ll want to use a GPU, which will mean either working locally or working out how to SSH into a cluster from solveit (potential follow-up post!).

As a next step, I’m keen to implement this more comprehensively. I’ve relied quite a lot on this repo, and it could be interesting to try adding LeJEPA to it, while potentially also adding evaluation so that we can compare LeJEPA to the repo’s original implementation and to other text encoders.